Tutorial Notebook 10: Finetuning for Perturbation Response Prediction¶

In this tutorial, we will demonstrate how to finetune a Cell2Sentence (C2S) model for perturbation response prediction tasks. This is a critical task in single-cell analysis, where the goal is to predict how a cell's gene expression profile changes in response to a specific perturbation (e.g., a genetic knockout or a drug treatment).

We will treat this as a "translation" task in natural language: translating a cell (in cell sentence format) from its basal (control) state to its perturbed state, conditioned on the perturbation applied.

At a high level, we will:

  1. Load a public single-cell perturbation dataset.
  2. Write a custom prompt template for perturbation prediction.
  3. Subclass the PromptFormatter class to create pairs of control and perturbed cells.
  4. Finetune a pretrained C2S-Scale model on this new task.
  5. Generate a prediction with our new finetuned model to see it in action.

First, let's import the necessary libraries. We'll need standard data science libraries, single-cell analysis tools, and modules from the cell2sentence and transformers packages.

In [2]:
######## Load modules ########
from __future__ import annotations  # postponed annotation evaluation (avoids NameError on forward references)
import os
## ensure model trains on one GPU
os.environ["CUDA_VISIBLE_DEVICES"] = "0" 
os.environ["WORLD_SIZE"] = "1"
import pickle
from datetime import datetime
import numpy as np
import random
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
import scipy
from tqdm import tqdm
import torch

# Cell2Sentence imports
import cell2sentence as cs
from cell2sentence.utils import benchmark_expression_conversion, reconstruct_expression_from_cell_sentence
from cell2sentence.tasks import embed_cells, predict_cell_types_of_data
from cell2sentence.prompt_formatter import get_cell_sentence_str, PromptFormatter #for custom prompt

# Hugging Face
from transformers import TrainingArguments, AutoModelForCausalLM
from datasets import Dataset # Arrow

# Single-cell libraries
import scanpy as sc
import anndata as ad
from collections import Counter, defaultdict #count table

sc.set_figure_params(dpi=300, color_map="viridis_r", facecolor="white")
sc.settings.verbosity = 3  # verbosity: errors (0), warnings (1), info (2), hints (3)
sc.logging.print_header()
/ix/ccdg/storage3/til177/custom_miniconda/envs/cell2sentence/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
scanpy==1.9.8 anndata==0.9.2 umap==0.5.7 numpy==1.24.4 scipy==1.10.1 pandas==2.0.3 scikit-learn==1.3.2 statsmodels==0.14.1 pynndescent==0.5.13

Load Perturbation Data¶

For this tutorial, you will need a single-cell dataset that includes both control and perturbed cells. The data should be in an AnnData object. Crucially, your .obs dataframe must contain:

  • A column that distinguishes control cells from perturbed cells, e.g., adata.obs['condition']
    • Values can be something like 'control' or 'non-targeting' for control cells, and 'perturbed' or the perturbation name for the perturbed cells

For this example, we use a public genetic perturbation dataset of Jurkat cells (Nadig et al., 2025). To use your own dataset, replace DATA_PATH with the path to your preprocessed data file.

  • Paper: https://www.nature.com/articles/s41588-025-02169-3
  • Note: given the observation count here, the deposited file may contain only a subset of the full dataset used in the paper

Ensure your data is preprocessed and normalized (e.g., using log1p transformation) before proceeding.
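Before converting to cell sentences, it can help to sanity-check that the metadata meets these requirements. Below is a minimal sketch (plain dicts standing in for `adata.obs` rows; the column and label names are illustrative, not from this dataset):

```python
from collections import Counter

def validate_condition_column(obs_records, condition_col="condition",
                              control_label="non-targeting"):
    """Check that per-cell metadata contains both control and perturbed cells."""
    counts = Counter(rec[condition_col] for rec in obs_records)
    n_control = counts.get(control_label, 0)
    n_perturbed = sum(v for k, v in counts.items() if k != control_label)
    assert n_control > 0, f"no '{control_label}' control cells found"
    assert n_perturbed > 0, "no perturbed cells found"
    return n_control, n_perturbed

# Toy metadata standing in for adata.obs rows
records = [
    {"condition": "non-targeting"},
    {"condition": "non-targeting"},
    {"condition": "TFAM"},
    {"condition": "PSMB5"},
]
print(validate_condition_column(records))  # (2, 2)
```

With a real AnnData object, you would pass `adata.obs.to_dict("records")` (or simply inspect `Counter(adata.obs["condition"])`, as we do later with `target_gene`).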

In [62]:
# Replace this with the actual path to your dataset, if using a custom dataset
DATA_PATH = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/GSE264667_jurkat.h5ad"
adata = ad.read_h5ad(DATA_PATH)
adata
Out[62]:
AnnData object with n_obs × n_vars = 262956 × 8882
    obs: 'gem_group', 'gene', 'gene_id', 'transcript', 'gene_transcript', 'sgID_AB', 'mitopercent', 'UMI_count', 'z_gemgroup_UMI'
    var: 'gene_name', 'chr', 'start', 'end', 'class', 'strand', 'length', 'in_matrix', 'mean', 'std', 'cv', 'fano'
In [63]:
adata.obs.head()
Out[63]:
gem_group gene gene_id transcript gene_transcript sgID_AB mitopercent UMI_count z_gemgroup_UMI
cell_barcode
AAACCCAAGAAACTGT-27 27 NELFE ENSG00000204356 P1P2 5601_NELFE_P1P2_ENSG00000204356 NELFE_+_31926720.23-P1P2|NELFE_-_31926676.23-P1P2 0.063665 13194.0 0.106271
AAACCCAAGAAATCCA-12 12 EMC7 ENSG00000134153 P1P2 2616_EMC7_P1P2_ENSG00000134153 EMC7_+_34394068.23-P1P2|EMC7_-_34393868.23-P1P2 0.049182 9719.0 -0.054858
AAACCCAAGAAATTCG-56 56 TAF1D ENSG00000166012 P2 8659_TAF1D_P2_ENSG00000166012 TAF1D_-_93471390.23-P2|TAF1D_+_93471338.23-P2 0.055632 11576.0 -0.138458
AAACCCAAGAAGCCAC-26 26 EIF2B2 ENSG00000119718 P1P2 2536_EIF2B2_P1P2_ENSG00000119718 EIF2B2_-_75469671.23-P1P2|EIF2B2_-_75469856.23... 0.044284 12849.0 -0.422243
AAACCCAAGACAACTA-5 5 RPP30 ENSG00000148688 P1P2 7491_RPP30_P1P2_ENSG00000148688 RPP30_+_92631924.23-P1P2|RPP30_-_92631746.23-P1P2 0.072090 11555.0 -1.806991
In [64]:
# Check that this barcode matches the tutorial data (it does)
barcode = "AAACCCAAGCACCAGA-42"
adata[barcode].obs
Out[64]:
gem_group gene gene_id transcript gene_transcript sgID_AB mitopercent UMI_count z_gemgroup_UMI
cell_barcode
AAACCCAAGCACCAGA-42 42 EIF4B ENSG00000063046 P1P2 2562_EIF4B_P1P2_ENSG00000063046 EIF4B_+_53400192.23-P1P2|EIF4B_-_53400313.23-P1P2 0.037697 4828.0 -1.484244

Preprocessing to match the tutorial format¶

The original tutorial expects an AnnData object of the following form, which we will match below:

AnnData object with n_obs × n_vars = 21412 × 18080
    obs: 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count'
In [65]:
adata.obs['cell_type'] = "jurkat"
adata.obs = adata.obs[['gem_group','cell_type','gene','gene_id','mitopercent','UMI_count']]
adata.obs = adata.obs.rename(columns={'gem_group': 'batch_var', 
                                      'gene'     : 'target_gene'})
adata.obs['batch_var'] = 'jurkat'+adata.obs['batch_var'].astype(str)
adata.obs.head()
Out[65]:
batch_var cell_type target_gene gene_id mitopercent UMI_count
cell_barcode
AAACCCAAGAAACTGT-27 jurkat27 jurkat NELFE ENSG00000204356 0.063665 13194.0
AAACCCAAGAAATCCA-12 jurkat12 jurkat EMC7 ENSG00000134153 0.049182 9719.0
AAACCCAAGAAATTCG-56 jurkat56 jurkat TAF1D ENSG00000166012 0.055632 11576.0
AAACCCAAGAAGCCAC-26 jurkat26 jurkat EIF2B2 ENSG00000119718 0.044284 12849.0
AAACCCAAGACAACTA-5 jurkat5 jurkat RPP30 ENSG00000148688 0.072090 11555.0
In [66]:
# Set gene_name as the index for adata.var
adata.var_names = adata.var['gene_name'].astype(str)
adata.var = pd.DataFrame(index=adata.var_names)
adata.var_names_make_unique()
adata.var.head()
Out[66]:
gene_name
LINC01409
LINC01128
NOC2L
HES4
ISG15

Check other features before normalization¶

In [67]:
print(Counter(adata.obs["batch_var"]))

# will be used to create the __cell sentences__
print(adata.var_names)

print(adata.X.max())
Counter({'jurkat44': 5936, 'jurkat3': 5847, 'jurkat7': 5743, 'jurkat12': 5711, 'jurkat5': 5618, 'jurkat25': 5517, 'jurkat47': 5498, 'jurkat29': 5467, 'jurkat2': 5466, 'jurkat55': 5377, 'jurkat4': 5349, 'jurkat53': 5343, 'jurkat22': 5280, 'jurkat16': 5243, 'jurkat39': 5237, 'jurkat24': 5232, 'jurkat30': 5196, 'jurkat33': 5184, 'jurkat56': 5164, 'jurkat34': 5112, 'jurkat49': 5098, 'jurkat21': 5086, 'jurkat42': 5076, 'jurkat13': 5061, 'jurkat11': 5044, 'jurkat38': 5023, 'jurkat51': 4983, 'jurkat48': 4916, 'jurkat40': 4902, 'jurkat14': 4900, 'jurkat31': 4879, 'jurkat23': 4873, 'jurkat26': 4859, 'jurkat6': 4835, 'jurkat54': 4818, 'jurkat43': 4723, 'jurkat46': 4716, 'jurkat17': 4650, 'jurkat10': 4633, 'jurkat45': 4588, 'jurkat20': 4586, 'jurkat35': 4494, 'jurkat19': 4435, 'jurkat8': 4386, 'jurkat18': 4326, 'jurkat32': 4240, 'jurkat28': 4188, 'jurkat37': 4173, 'jurkat9': 4152, 'jurkat15': 4080, 'jurkat27': 4048, 'jurkat52': 3672, 'jurkat36': 3457, 'jurkat1': 1666, 'jurkat50': 870})
Index(['LINC01409', 'LINC01128', 'NOC2L', 'HES4', 'ISG15', 'TNFRSF4', 'SDF4',
       'B3GALT6', 'UBE2J2', 'ACAP3',
       ...
       'MT-ATP6', 'MT-CO3', 'MT-ND3', 'MT-ND4L', 'MT-ND4', 'MT-ND5', 'MT-ND6',
       'MT-CYB', 'AC240274.1', 'AC004556.3'],
      dtype='object', name='gene_name', length=8882)
2913.0

Preprocessing & Normalization¶

In [68]:
# basic filtering
print(adata)
AnnData object with n_obs × n_vars = 262956 × 8882
    obs: 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count'
In [69]:
# default in tutorial
sc.pp.filter_cells(adata, min_genes=200)
sc.pp.filter_genes(adata, min_cells=3)
In [70]:
print(adata) # check whether the counts decreased (no cells or genes were removed here)
AnnData object with n_obs × n_vars = 262956 × 8882
    obs: 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count', 'n_genes'
    var: 'n_cells'
In [71]:
# QC metrics
adata.var["mt"] = adata.var_names.str.startswith("MT-")
sc.pp.calculate_qc_metrics(
    adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True
)
In [72]:
print(f"Median UMI: {np.median(adata.obs['UMI_count']):.0f}")
Median UMI: 10160
In [73]:
sc.pl.violin(
    adata,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt"],
    jitter=0.4,
    multi_panel=True,
)
[Violin plots of n_genes_by_counts, total_counts, and pct_counts_mt]
In [76]:
# For Jurkat Perturb-seq with median UMI ~10k
min_umi = 1000    # ~10% of median
max_umi = 40000   # ~4x median (doublet filter)
min_genes = 500
max_genes = 6000
max_mito = 15

# Apply filters
adata = adata[
    (adata.obs['UMI_count'] > min_umi) &
    (adata.obs['UMI_count'] < max_umi) &
    (adata.obs['n_genes'] > min_genes) &
    (adata.obs['n_genes'] < max_genes) &
    (adata.obs['mitopercent'] < max_mito)
].copy()

print(adata)

sc.pl.violin(
    adata,
    ["n_genes_by_counts", "total_counts", "pct_counts_mt"],
    jitter=0.4,
    multi_panel=True,
)
AnnData object with n_obs × n_vars = 257412 × 8882
    obs: 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
[Violin plots of n_genes_by_counts, total_counts, and pct_counts_mt after filtering]
In [79]:
#### Normalization ####
# Count normalization
sc.pp.normalize_total(adata)
# Log1p transformation with base 10 (base 10 is important for converting cell sentences back to expression!)
sc.pp.log1p(adata, base=10)  
# check --> ~3.4, which is expected for a base-10 log transformation.
print(adata.X.max())
normalizing counts per cell
    finished (0:00:00)
3.3689852
In [80]:
#### Visualization ####
sc.tl.pca(adata)
sc.pp.neighbors(adata)
sc.tl.umap(adata)
computing PCA
    with n_comps=50
    finished (0:00:50)
computing neighbors
    using 'X_pca' with n_pcs = 50
    finished: added to `.uns['neighbors']`
    `.obsp['distances']`, distances for each pair of neighbors
    `.obsp['connectivities']`, weighted adjacency matrix (0:00:35)
computing UMAP
    finished: added
    'X_umap', UMAP coordinates (adata.obsm) (0:03:06)
In [83]:
# adata[adata.obs['batch_var'] == 'jurkat1']

# create a folder to store plots
# TUT10_PLOT_DIR = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/output_tut10_preprocessing_umap"
# os.makedirs(TUT10_PLOT_DIR, exist_ok=True)
# # set Scanpy figure directory
# sc.settings.figdir = TUT10_PLOT_DIR
sc.pl.umap(
    adata[adata.obs['batch_var'].isin(['jurkat1', 'jurkat2'])],
    color="batch_var",
    size=8,
    title="Jurkat Perturb-seq UMAP",
    #save="_batch_var.png",# will save in current working directory unless sc.settings.figdir is set
)
/ix/ccdg/storage3/til177/custom_miniconda/envs/cell2sentence/lib/python3.8/site-packages/scanpy/plotting/_tools/scatterplots.py:394: UserWarning: No data for colormapping provided via 'c'. Parameters 'cmap' will be ignored
[UMAP of jurkat1 and jurkat2 cells, colored by batch_var; the two batches largely overlap]

This overlap is expected: the data are scRNA-seq from a single cell line under perturbation, and most individual perturbations have a minimal global effect on the transcriptome.

In [84]:
# save data
SAVE_PATH = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/GSE264667_jurkat_processed.h5ad"
adata.write_h5ad(SAVE_PATH)


Reload Perturbation Data¶

In [3]:
DATA_PATH = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/GSE264667_jurkat_processed.h5ad"
adata = ad.read_h5ad(DATA_PATH)
adata
Out[3]:
AnnData object with n_obs × n_vars = 257412 × 8882
    obs: 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count', 'n_genes', 'n_genes_by_counts', 'total_counts', 'total_counts_mt', 'pct_counts_mt'
    var: 'n_cells', 'mt', 'n_cells_by_counts', 'mean_counts', 'pct_dropout_by_counts', 'total_counts'
    uns: 'batch_var_colors', 'log1p', 'neighbors', 'pca', 'umap'
    obsm: 'X_pca', 'X_umap'
    varm: 'PCs'
    obsp: 'connectivities', 'distances'
In [4]:
# print(Counter(adata.obs["batch_var"]))
print(adata.var_names)   # will be used to create the __cell sentences__
print(adata.X.max())     # check max value (log10 transformation expects a maximum value somewhere around 3 or 4)
print(adata.obs.columns) # check colnames (for next step)
Index(['LINC01409', 'LINC01128', 'NOC2L', 'HES4', 'ISG15', 'TNFRSF4', 'SDF4',
       'B3GALT6', 'UBE2J2', 'ACAP3',
       ...
       'MT-ATP6', 'MT-CO3', 'MT-ND3', 'MT-ND4L', 'MT-ND4', 'MT-ND5', 'MT-ND6',
       'MT-CYB', 'AC240274.1', 'AC004556.3'],
      dtype='object', name='gene_name', length=8882)
3.3689852
Index(['batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent',
       'UMI_count', 'n_genes', 'n_genes_by_counts', 'total_counts',
       'total_counts_mt', 'pct_counts_mt'],
      dtype='object')
In [5]:
target_gene_counter = Counter(adata.obs['target_gene'])
print(len(target_gene_counter))
2394
In [6]:
target_gene_counter.most_common(20)
Out[6]:
[('non-targeting', 11742),
 ('TFAM', 2506),
 ('SLC1A5', 1697),
 ('GFM1', 1349),
 ('GTF3C4', 1266),
 ('PSMB5', 1218),
 ('MRPL36', 1167),
 ('PPP6C', 1151),
 ('NBPF12', 1120),
 ('MRPL35', 977),
 ('POGLUT3', 970),
 ('TARDBP', 816),
 ('MRPL34', 789),
 ('CCDC6', 776),
 ('BCAR1', 766),
 ('GTF2E2', 739),
 ('GAB2', 673),
 ('TRNT1', 656),
 ('HSD17B10', 648),
 ('THAP1', 623)]

This data contains both control cells (non-targeting) as well as cells under different genetic knockouts.

In [7]:
adata.X
Out[7]:
array([[0.        , 0.        , 0.        , ..., 1.8568417 , 0.4022843 ,
        0.24614553],
       [0.        , 0.        , 0.        , ..., 1.8262414 , 0.        ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 1.8621143 , 0.27068487,
        0.27068487],
       ...,
       [0.        , 0.49683046, 0.        , ..., 1.7276423 , 0.3158951 ,
        0.        ],
       [0.        , 0.        , 0.        , ..., 1.8270676 , 0.        ,
        0.        ],
       [0.        , 0.        , 0.36475694, ..., 1.7797725 , 0.36475694,
        0.36475694]], dtype=float32)
In [8]:
print(adata.X.max())
3.3689852

The expression is already preprocessed and log1p transformed. Typically, a log1p transform with base=10 would be used for Cell2Sentence training if we wish to convert generated cell sentences back to expression vectors, since log base=10 gives a better linear relationship between log rank and log expression in many single-cell datasets.

For this tutorial, we will use this processed data as is, to train our model to generate target cell sentences.
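To see why base 10 matters, here is a toy sketch (entirely synthetic data, not this dataset) of the linear relationship between log rank and log expression that underlies converting generated cell sentences back to expression vectors:

```python
import numpy as np

rng = np.random.default_rng(0)
# Synthetic expression vector for one cell, sorted from highest to lowest
expr = np.sort(rng.exponential(scale=0.5, size=500))[::-1]

nonzero = expr[expr > 0]
ranks = np.arange(1, len(nonzero) + 1)   # rank 1 = highest-expressed gene
log_rank = np.log10(ranks)
log_expr = np.log10(nonzero)

# Linear fit: log expression ~ slope * log rank + intercept
slope, intercept = np.polyfit(log_rank, log_expr, 1)
pred = slope * log_rank + intercept
ss_res = np.sum((log_expr - pred) ** 2)
ss_tot = np.sum((log_expr - np.mean(log_expr)) ** 2)
r2 = 1 - ss_res / ss_tot
print(f"slope={slope:.2f}, intercept={intercept:.2f}, R^2={r2:.2f}")
```

A fit like this (negative slope, high R²) is what `reconstruct_expression_from_cell_sentence` relies on to map a gene's position in the sentence back to an approximate expression value.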

In [9]:
# adata_subset = adata[adata.obs['batch_var'].isin(['jurkat1', 'jurkat2'])]
# print(adata_subset)
# print(adata_subset.X.max())

# target_gene_counter = Counter(adata_subset.obs['target_gene'])
# print(len(target_gene_counter))
# target_gene_counter.most_common(20)

Cell2Sentence Conversion¶

Now, we will convert the gene expression data in our AnnData object into cell sentences. This process creates a Hugging Face Arrow dataset, which is used in our LLM training.

In [9]:
# We'll keep all relevant columns for our new task
adata_obs_cols_to_keep = ['batch_var','cell_type','target_gene','gene_id','mitopercent','UMI_count']
adata_obs_cols_to_keep
Out[9]:
['batch_var',
 'cell_type',
 'target_gene',
 'gene_id',
 'mitopercent',
 'UMI_count']
In [10]:
# Create Arrow dataset and vocabulary
arrow_ds, vocabulary = cs.CSData.adata_to_arrow(
    adata=adata, 
    random_state=SEED, 
    sentence_delimiter=' ',
    label_col_names=adata_obs_cols_to_keep
)
arrow_ds
100%|██████████| 257412/257412 [01:41<00:00, 2543.44it/s]
Out[10]:
Dataset({
    features: ['cell_name', 'cell_sentence', 'batch_var', 'cell_type', 'target_gene', 'gene_id', 'mitopercent', 'UMI_count'],
    num_rows: 257412
})
In [13]:
sample_idx = 0
arrow_ds[sample_idx]
##   Check cell sentence length
print(len(arrow_ds[sample_idx]["cell_sentence"].split(" ")))
##   Check feature info
print(type(vocabulary))
print(len(vocabulary))
print(list(vocabulary.items())[:10])
4016
<class 'collections.OrderedDict'>
8882
[('LINC01409', 34288), ('LINC01128', 60189), ('NOC2L', 142586), ('HES4', 182061), ('ISG15', 227707), ('TNFRSF4', 22688), ('SDF4', 123657), ('B3GALT6', 94370), ('UBE2J2', 140413), ('ACAP3', 45948)]

Custom Prompt Formatting for Perturbation Prediction¶

This is the core of our tutorial. For this dataset, we have a single large pool of control cells (labeled 'non-targeting') and multiple groups of perturbed cells, each with a specific target_gene.

Our goal is to create training pairs where each perturbed cell is matched with a randomly selected control cell. Note that for different perturbation applications, there may be better ways of pairing control and perturbed cells.

We will define a custom prompt structure that frames our task for the LLM. The input will contain the control cell sentence and the perturbation name. The model's expected output (the response) will be the perturbed cell sentence.

First, let's define our prompt templates.

In [14]:
# The input provides the control cell and the perturbation, asking for the perturbed result.
custom_input_prompt_template = """Given the following cell sentence of {num_genes} expressed genes representing a cell's basal state, predict the cell sentence after applying the perturbation: {perturbation_name}.
Control cell sentence: {control_cell_sentence}.

Perturbed cell sentence:"""

# The answer is simply the target cell sentence.
answer_template = "{perturbed_cell_sentence}."

To apply this template, we need to create pairs of (control cell, perturbed cell) for each perturbation. We'll create a custom PerturbationPromptFormatter by subclassing the base PromptFormatter.

Our custom format_hf_ds function will:

  1. First, iterate through the entire dataset to create a list of all control cell indices.
  2. Simultaneously, it will group the indices of all perturbed cells into a dictionary, with the perturbation name (target_gene) as the key.
  3. Then, it will iterate through each perturbation group and, for every perturbed cell, randomly select a control cell from the global pool to form a pair.
  4. Finally, it will format these pairs into the input prompts and expected responses for the model.
In [15]:
# from collections import defaultdict

class PerturbationPromptFormatter(PromptFormatter):
    def __init__(self,
        task_name,
        input_prompt,
        answer_template,
        top_k_genes, 
        perturbation_col='target_gene',
        control_label='non-targeting'
    ):
        """
        Initializes the custom prompt formatter.

        Args:
            task_name (str): The name for this task.
            input_prompt (str): The template for the model's input.
            answer_template (str): The template for the model's expected response.
            top_k_genes (int): The number of top genes to include in the cell sentence.
            perturbation_col (str): The column name in the dataset that contains perturbation info.
            control_label (str): The label used to identify control cells in the perturbation_col.
        """
        super().__init__()
        self.task_name = task_name
        self.input_prompt = input_prompt
        self.answer_template = answer_template
        self.top_k_genes = top_k_genes
        self.perturbation_col = perturbation_col
        self.control_label = control_label
        assert top_k_genes > 0, "'top_k_genes' must be an integer > 0"

    def format_hf_ds(self, hf_ds):
        """
        Custom formatting function for perturbation prediction. This function creates pairs of
        (control, perturbed) cells by sampling from a global pool of control cells.
        """
        model_inputs_list = []
        responses_list = []
        
        # 1. Separate all cells into a global control pool and a dict of perturbed cells
        control_indices = []
        pert_to_indices = defaultdict(list)

        print("Grouping cells by perturbation and identifying global controls...")
        for i, sample in enumerate(hf_ds):
            if sample[self.perturbation_col] == self.control_label:
                control_indices.append(i)
            else:
                pert_to_indices[sample[self.perturbation_col]].append(i)

            # For each cell (sample) in the dataset:
            # If it's a control cell (e.g., target_gene == 'non-targeting'): add its index to control_indices
            # If it's perturbed (e.g., target_gene == 'BRCA1'): add its index to the pert_to_indices dictionary under that perturbation name
        
        assert len(control_indices) > 0, "No control cells found. Cannot create pairs."
        print(f"Found {len(control_indices)} control cells.")
        print(f"Found {len(pert_to_indices)} unique perturbations.")

        # 2. Create prompt-response pairs by iterating through perturbed cells
        print("Creating control-perturbed pairs...")
        for pert_name, perturbed_indices in tqdm(pert_to_indices.items()):
            for perturbed_idx in perturbed_indices:
                # Pair each perturbed cell with a random control cell from the global pool
                control_idx = random.choice(control_indices)
                
                control_sample = hf_ds[control_idx]
                perturbed_sample = hf_ds[perturbed_idx]

                # Format control cell sentence
                control_sentence, num_genes_str = get_cell_sentence_str(#https://github.com/vandijklab/cell2sentence/blob/a6efaf079f98491d4723ced44b929936b94368aa/src/cell2sentence/prompt_formatter.py#L31
                    control_sample,
                    num_genes=self.top_k_genes
                )
                # Format perturbed cell sentence
                perturbed_sentence, _ = get_cell_sentence_str(
                    perturbed_sample,
                    num_genes=self.top_k_genes
                )
                
                #### Matches the template fstring ####
                # Format the model input string using the perturbation name
                model_input_str = self.input_prompt.format(
                    num_genes=num_genes_str,
                    perturbation_name=pert_name,
                    control_cell_sentence=control_sentence
                )
                # Format the response string
                response_str = self.answer_template.format(
                    perturbed_cell_sentence=perturbed_sentence
                )

                model_inputs_list.append(model_input_str)
                responses_list.append(response_str)

        # Create the final Hugging Face Dataset
        ds_split_dict = {
            "sample_type": [self.task_name] * len(model_inputs_list),
            "model_input": model_inputs_list,
            "response": responses_list,
        }
        ds = Dataset.from_dict(ds_split_dict)
        return ds

Let's instantiate our custom formatter and test it on a small sample of our data to see the result.

In [16]:
task_name = "perturbation_prediction"
prompt_formatter = PerturbationPromptFormatter(
    task_name=task_name,
    input_prompt=custom_input_prompt_template,
    answer_template=answer_template,
    top_k_genes=200 # Top 200 genes for this example; for real applications, ideally use all nonzero expressed genes
)
In [17]:
# Format the dataset
formatted_ds = prompt_formatter.format_hf_ds(arrow_ds)
formatted_ds
Grouping cells by perturbation and identifying global controls...
Found 11742 control cells.
Found 2393 unique perturbations.
Creating control-perturbed pairs...
100%|██████████| 2393/2393 [00:58<00:00, 41.06it/s] 
Out[17]:
Dataset({
    features: ['sample_type', 'model_input', 'response'],
    num_rows: 245670
})
In [18]:
type(formatted_ds)
Out[18]:
datasets.arrow_dataset.Dataset

Note that if you want a train/test split of the data, you should separate out held-out control cells and held-out perturbed cells (or entire held-out perturbations) before formatting.
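One way to sketch such a split is to hold out entire perturbations, so the model is evaluated on perturbations it never saw during training. This toy example uses plain lists standing in for `arrow_ds["target_gene"]`; the resulting index lists could then be passed to `arrow_ds.select()`:

```python
import random

random.seed(0)

# Toy per-cell labels standing in for arrow_ds["target_gene"]
target_genes = ["non-targeting"] * 4 + ["TFAM"] * 3 + ["PSMB5"] * 2 + ["PPP6C"] * 2

# Hold out entire perturbations (here, one at random)
perturbations = sorted({g for g in target_genes if g != "non-targeting"})
heldout = set(random.sample(perturbations, k=1))

# Controls stay available to both splits for pairing; perturbed cells are split
train_idx = [i for i, g in enumerate(target_genes)
             if g == "non-targeting" or g not in heldout]
test_idx = [i for i, g in enumerate(target_genes) if g in heldout]

print(f"held-out perturbations: {sorted(heldout)}")
print(f"{len(train_idx)} train cells, {len(test_idx)} test cells")
```

Whether controls should be shared across splits, or also partitioned, depends on your evaluation design.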

In [19]:
# Inspect a formatted sample
print("--- Formatted Sample ---")
print("#----Model input:----#")
print(formatted_ds[0]["model_input"], "\n")
print("#----Response:----#")
print(formatted_ds[0]["response"])
--- Formatted Sample ---
#----Model input:----#
Given the following cell sentence of 200 expressed genes representing a cell's basal state, predict the cell sentence after applying the perturbation: NELFE.
Control cell sentence: MALAT1 MT-CO3 MT-CO2 EEF1A1 RPL10 MT-ATP6 RPS2 RPS3A RPL13 GAPDH ACTB MT-CO1 RPS6 MT-CYB RPL6 PTMA RPS7 RPS23 TPT1 RPL19 RPS4X RPS19 RPL7A RACK1 MT-ND4 RPL18A RPL7 RPS12 RPS3 FTH1 RPS18 HSP90AA1 RPL3 RPS5 RPLP1 HMGB1 RPS27A RPL5 RPS14 TMSB4X RPL11 RPL9 RPS24 RPS15 RPLP0 TUBA1B RPL17 RPL18 RPL26 EEF1B2 RPS13 YBX1 RPL37 RPS16 RPL8 H3F3A RPL13A TUBB RPL15 RPS9 STMN1 RPL23A RPSA RPL28 MIF NCL NACA PPIA RPL41 NPM1 RPL21 RPL30 EEF1G RPL29 HMGN2 MT-ND1 RPS21 RPL24 RPS27 RPS8 H2AFZ ACTG1 HSP90AB1 RPLP2 RPL32 EIF4A1 HSPD1 CFL1 RPL10A HNRNPA1 HNRNPA2B1 CHCHD2 RPL35A HIST1H1D RPS15A SET ARHGDIB HNRNPD EIF1 FAU CALR BTF3 HIST1H4C RPL34 RAN RPL35 GNAS ANP32B MT-ND2 PFN1 RPL37A TYMS RPL27A GSTP1 PA2G4 SNHG29 RPL4 CD3D RPL36 TRBC1 RPL27 RPL12 NUCB2 RPL23 DUT HMGB2 RPS29 SUB1 HIST1H1B SRRM1 RPL39 HNRNPU RPS25 MT-ND3 RANBP1 RPL14 C1QBP HMGN1 SLC3A2 HINT1 PRDX1 NME2 HSPA8 EIF5A FTL CDK6 RPS28 HSPA9 CHI3L2 RPS26 LDHB CCT3 CCT6A SLC25A3 HSPE1 ARPC3 RPL22 EEF2 NDUFV2 ALYREF STIP1 CD99 HNRNPA3 RPS20 SELENOH UBA52 PABPC1 SRSF9 H3F3B ENO1 ATF4 CCT7 GADD45GIP1 PRDX5 CSNK1A1 XRCC5 LDHA B2M HIST1H1E PSMD8 TCP1 NOP53 FUS RPS17 COX4I1 SERF2 PSMA6 EIF3M RPS11 UBC H1FX SOD1 SOX4 RPL36A HNRNPAB CBX3 ATP5MC2 PSMA3 VIM CCT8.

Perturbed cell sentence: 

#----Response:----#
MT-CO3 MT-ATP6 MALAT1 MT-CO2 MT-ND4 MT-CO1 EEF1A1 RPS2 MT-CYB RPS3 PTMA RPL10 RPL13 RPS3A TMSB4X ACTB GAPDH RPS18 RPS7 RPS19 RPS12 RPS6 RPS27A MT-ND1 RPL19 RACK1 HIST1H1D RPLP1 RPS15 RPS14 RPS23 RPS5 RPL13A TPT1 RPL18 RPL3 RPLP0 RPS4X RPL7A RPL30 MT-ND3 RPL11 RPL32 RPS16 HIST1H4C RPL15 RPL6 RPL28 TUBB HSP90AB1 RPS9 RPL8 RPL18A RPL29 RPL37 RPL7 RPL26 RPL17 NPM1 RPS24 RPS15A RPS13 RPL5 MT-ND2 RPL10A RPL9 HSP90AA1 RPS8 RPSA RPL36 H2AFZ RPL34 RPS11 EEF1B2 NCL RPS27 TUBA1B RPL21 RPLP2 RPL37A MIF HSPA8 RPL24 RPS25 STMN1 RPL14 RAN GSTP1 HMGB1 RPL12 GNAS NACA FAU HNRNPA2B1 EEF1G HSPD1 FTH1 RPL35 PFN1 MT-ND5 RPL35A LDHB PPIA TYMS CHCHD2 PRDX1 SNHG29 YBX1 RPS17 RPL4 UBA52 ZFAS1 RPL23A HNRNPA1 MYO7B RPL23 BTF3 RPS28 EIF1 H3F3A RPL27A CALR RPL41 CCT2 HINT1 RPL27 ARHGDIB PCNA RPL22 RPL39 CCT6A PSMA6 EIF3A EIF4A1 CDK6 APRT CFL1 B2M RPS20 RANBP1 RPL31 RPS29 ACTG1 SNRPB NAP1L1 TRBC1 NUCB2 NME2 SRSF3 UQCRH PTGES3 ITGA4 MT-ND6 PAICS COX5A SERF2 DYNLL1 RPL36A PSMA7 CANX HMGN2 SELENOH SNHG5 HNRNPD SSBP1 RPS26 CCT3 HNRNPAB EIF5 TCP1 HMGB2 COX7C OAZ1 ADA HNRNPU HES4 CD3D PCLAF PPP1R14B COX6A1 MTDH FUS TLE5 HSPH1 UBE2I TMSB10 ATP5PF GAS5 PSMA3 EIF5A HNRNPF BRCA2 NEAT1 NHP2 DDX46 SEPTIN6 H3F3B RPS10 EIF3E SF3B1.

Now that our custom formatter is ready, we'll wrap our original arrow_ds in a CSData object. The fine_tune() function will use this object and our custom formatter to prepare the data for training.

In [20]:
# Save directory for Huggingface dataset
c2s_save_dir = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq"
c2s_save_name = "jurkat_perturbation_c2s"
In [21]:
csdata = cs.CSData.csdata_from_arrow(
    arrow_dataset=arrow_ds,  # pass the unformatted cell sentence dataset; fine_tune() will apply the prompt formatter
    vocabulary=vocabulary,
    save_dir=c2s_save_dir,
    save_name=c2s_save_name,
    dataset_backend="arrow"
)
print(csdata)
Saving the dataset (12/12 shards): 100%|██████████| 257412/257412 [00:15<00:00, 16490.89 examples/s]
CSData Object; Path=/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/jurkat_perturbation_c2s, Format=arrow

Load a Pretrained Cell2Sentence Model¶

We will start with a C2S model that has already been pretrained on a diverse set of single-cell datasets. This provides a strong foundation of biological knowledge. The C2S-Scale-Pythia-1b-pt and newer C2S-Scale models are good general-purpose models to start from.

In [30]:
model_name_or_path = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/C2S_models/pythia-1b"
save_dir = "/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq"
save_name = "perturbation_pythia_1B"
In [31]:
csmodel = cs.CSModel(
    model_name_or_path=model_name_or_path,
    save_dir=save_dir,
    save_name=save_name
)
print(csmodel)
Using device: cpu
CSModel Object; Path=/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/perturbation_pythia_1B

Finetune for Perturbation Prediction (Switch to py file)¶

Now, we'll finetune our model on the perturbation prediction task. We'll define our TrainingArguments and then call the fine_tune() method, passing in our csdata object and the PerturbationPromptFormatter instance we created.

For this tutorial, we'll run for a small number of steps (max_steps=500). For a full finetuning run, you would typically train for several epochs.

In [28]:
datetimestamp = datetime.now().strftime('%Y-%m-%d-%H_%M_%S')
output_dir = os.path.join(csmodel.save_dir, f"finetunedModel_{datetimestamp}_testFinetune_{task_name}")
print(output_dir)
/ix/ccdg/storage3/til177/ParkLab/Project_C2S_scFM/code/tutorials/data_PerturbSeq/finetunedModel_2025-12-09-14_29_09_testFinetune_perturbation_prediction
In [24]:
if not os.path.exists(output_dir):
    os.mkdir(output_dir)
In [25]:
train_args = TrainingArguments(
    bf16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=4,
    learning_rate=1e-5,
    logging_steps=50,
    lr_scheduler_type="cosine",
    num_train_epochs=1, 
    eval_steps=50,
    evaluation_strategy="steps",
    save_steps=100,
    save_strategy="steps",
    output_dir=output_dir,
    max_steps=500  # Shortened for tutorial purposes
)
/gpfs/radev/home/sr2464/.conda/envs/cell2sentence2/lib/python3.8/site-packages/transformers/training_args.py:1559: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead
In [26]:
csmodel.fine_tune(
    csdata=csdata,
    task=task_name,
    train_args=train_args,
    loss_on_response_only=True, # We only want to calculate loss on the predicted sentence
    top_k_genes=200,  # Use top 200 genes for this example, normally would use full cell sentence (all nonzero expressed genes) if possible
    prompt_formatter=prompt_formatter  # Pass in our custom prompt formatter
)
Grouping cells by perturbation and identifying global controls...
Found 12013 control cells.
Found 68 unique perturbations.
Creating control-perturbed pairs...
100%|██████████| 68/68 [00:03<00:00, 20.78it/s]
Reloading model from path on disk: /home/sr2464/scratch/C2S_API_Testing/Cache_Dir/perturbation_tutorial/perturbation_pythia_1B
Map (num_proc=3):   0%|          | 0/9399 [00:00<?, ? examples/s]
Starting training. Output directory: /home/sr2464/scratch/C2S_API_Testing/Cache_Dir/perturbation_tutorial/2025-10-14-23_33_53_finetune_perturbation_prediction
Selecting 500 samples of eval dataset to shorten validation loop.
max_steps is given, it will override any value given in num_train_epochs
Tracking run with wandb version 0.16.6
Run data is saved locally in /home/sr2464/Desktop/cell2sentence/tutorials/wandb/run-20251014_233441-fa2rexvg
Syncing run /home/sr2464/scratch/C2S_API_Testing/Cache_Dir/perturbation_tutorial/2025-10-14-23_33_53_finetune_perturbation_prediction to Weights & Biases (docs)

View project at https://wandb.ai/syed-a-rizvi/huggingface
View run at https://wandb.ai/syed-a-rizvi/huggingface/runs/fa2rexvg
(live training progress display: step count, training loss, validation loss)

Finetuning completed. Updated model saved to disk at: /home/sr2464/scratch/C2S_API_Testing/Cache_Dir/perturbation_tutorial/2025-10-14-23_33_53_finetune_perturbation_prediction
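Setting `loss_on_response_only=True` corresponds to the standard causal-LM label-masking trick: prompt tokens are assigned the ignore index -100 so that only the response (the perturbed cell sentence) contributes to the loss. A minimal sketch of the idea on plain token-id lists (not the actual C2S implementation; `mask_prompt_labels` is a hypothetical helper):

```python
IGNORE_INDEX = -100  # PyTorch's CrossEntropyLoss skips targets with this value

def mask_prompt_labels(input_ids, prompt_len):
    """Copy token ids into labels, masking the prompt so loss covers only the response."""
    labels = list(input_ids)
    labels[:prompt_len] = [IGNORE_INDEX] * prompt_len
    return labels

# Toy sequence: 3 prompt tokens followed by 2 response tokens.
print(mask_prompt_labels([11, 12, 13, 21, 22], prompt_len=3))  # [-100, -100, -100, 21, 22]
```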

Generate Predictions with the Finetuned Model¶

After finetuning, let's load our new model and use it to predict the response to a perturbation for a cell from our test set.

In [27]:
final_ckpt_path = os.path.join(output_dir, "checkpoint-500")
final_ckpt_path
Out[27]:
'/home/sr2464/scratch/C2S_API_Testing/Cache_Dir/perturbation_tutorial/2025-10-14-23_33_53_finetune_perturbation_prediction/checkpoint-500'
In [28]:
save_dir
Out[28]:
'/home/sr2464/scratch/C2S_API_Testing/Cache_Dir/perturbation_tutorial'
In [29]:
# Load the finetuned model from its final checkpoint
finetuned_model = cs.CSModel(
    model_name_or_path=final_ckpt_path, # Path to the final finetuning checkpoint
    save_dir=save_dir,
    save_name="perturbation_predictor_finetuned_final"
)
Using device: cuda
In [30]:
finetuned_model.save_path
Out[30]:
'/home/sr2464/scratch/C2S_API_Testing/Cache_Dir/perturbation_tutorial/perturbation_predictor_finetuned_final'
In [31]:
device = "cuda" if torch.cuda.is_available() else "cpu"
device
Out[31]:
'cuda'
In [32]:
final_model = AutoModelForCausalLM.from_pretrained(
    finetuned_model.save_path,
    cache_dir=os.path.join(csmodel.save_dir, ".cache"),
    trust_remote_code=True
).to(device)
In [33]:
# Load the train/val/test split indices saved by fine_tune() to the output directory
with open(os.path.join(output_dir, 'data_split_indices_dict.pkl'), 'rb') as f:
    data_split_indices_dict = pickle.load(f)

data_split_indices_dict.keys()
Out[33]:
dict_keys(['train', 'val', 'test'])
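As a quick sanity check, the three splits should be disjoint, so no cell is seen during both training and evaluation. A sketch, assuming the split values are plain integer lists as shown (the toy dict below stands in for the real `data_split_indices_dict`):

```python
# Hypothetical split dict standing in for the one loaded above.
data_split_indices_dict = {
    "train": [0, 1, 2, 3],
    "val": [4, 5],
    "test": [7, 29, 33],
}

train = set(data_split_indices_dict["train"])
val = set(data_split_indices_dict["val"])
test = set(data_split_indices_dict["test"])

# No cell index should appear in more than one split.
assert train.isdisjoint(val) and train.isdisjoint(test) and val.isdisjoint(test)
print("splits are disjoint")
```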
In [34]:
# Print a few indices of test samples
data_split_indices_dict['test'][:10]
Out[34]:
[7, 29, 33, 35, 51, 54, 65, 70, 115, 116]
In [35]:
# Select a few unseen samples
formatted_test_ds = formatted_ds.select(data_split_indices_dict['test'][:10])
formatted_test_ds
Out[35]:
Dataset({
    features: ['sample_type', 'model_input', 'response'],
    num_rows: 10
})
In [36]:
# Select a sample from the test set for inference
sample_idx = 0
inference_prompt = formatted_test_ds[sample_idx]['model_input']
ground_truth_response = formatted_test_ds[sample_idx]['response']

print("--- Inference Prompt ---")
print(inference_prompt)
--- Inference Prompt ---
Given the following cell sentence of 200 expressed genes representing a cell's basal state, predict the cell sentence after applying the perturbation: EIF4B.
Control cell sentence: ACTB MT-ATP6 MT-CO3 MT-CO2 TMSB4X TUBA1B MT-ND4 HSP90AA1 HIST1H4C TMSB10 MT-CYB ACTG1 TUBB RACK1 PFN1 H3F3A ARHGDIB YBX1 MT-ND1 CFL1 H2AFZ B2M FTH1 HSP90AB1 CENPF UBA52 MIF MT-ND3 LDHB NCL PPIA SERF2 HNRNPA2B1 EEF1B2 CDK6 PSMA7 HSPD1 TOP2A KPNA2 GSTP1 BTF3 MT-ND2 CD3D SET CORO1A UBB ANP32B HNRNPU STMN1 EIF4A1 CCT3 EEF1G EEF2 SOX4 EIF1 BANF1 HNRNPD ADA COX4I1 MKI67 COX6C SUMO2 LCP1 SLC25A3 PKM CHCHD2 HINT1 OAZ1 NDUFS6 ENO1 TUBB4B HMGB2 MACF1 NME1 LSM4 SELENOH HSPE1 ERH NUCKS1 HMGN1 DDX5 NDUFB10 HIST1H1E GPX4 NDUFA13 GTF3A MT-ND6 PTGES3 ARL6IP1 RBMX PGK1 CTCF ATP5MC3 PSMB2 SLC25A5 PRDX5 COX7C SPTBN1 BPTF ANXA1 UBE2C PRDX1 SRSF10 NDUFA12 SFPQ EIF3G TOMM22 SIVA1 PHB2 CCT2 PA2G4 PCBP2 SUB1 CCT6A RHOA ASPM YWHAE ATP5MC1 XRCC5 SMC4 SRSF3 YBX3 HNRNPAB ATP5F1B CSDE1 SNRPB IFI16 SNRPD1 PAFAH1B3 MSN XRCC6 COX6A1 COX8A ATP5F1E ALYREF CBX3 MYL6 EIF3E RBM39 RAC1 UBE2D2 C1QBP KHDRBS1 CCT4 JPT1 RAD21 NME2 PSME2 SEPTIN7 SNU13 CKLF NMRAL1 NOP53 EIF3I EIF3D SAP18 SON HTATSF1 PSMA3 SEC61B AIP RBM8A C11ORF58 PRRC2C OSTC NDUFAB1 NDUFA4 SELENOW HNRNPA3 SRSF7 H3F3B TOMM20 SRSF9 HNRNPK CCT8 ARPC2 ATP5F1A ATP5F1C PHB EIF4B NUCB2 SSBP1 CD3G TMEM50A HNRNPM BUB3 SRSF2 PSMC5 KDM5A SNRNP25 PAK2 MYL12A NIFK SNRPF SLC25A6 KIF14 PRMT1 MTDH SRRM2 ATP5PF.

Perturbed cell sentence:
In [37]:
# Generate the prediction
predicted_response = finetuned_model.generate_from_prompt(
    model=final_model,
    prompt=inference_prompt,
    max_num_tokens=800 # max number of tokens to generate, ~4 tokens per gene
)
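The `max_num_tokens` budget follows from the cell sentence length: at roughly 4 tokens per gene symbol (the estimate in the comment above), a 200-gene sentence needs about 800 generated tokens:

```python
genes_per_sentence = 200       # top_k_genes used during finetuning
approx_tokens_per_gene = 4     # rough tokenizer estimate from the comment above
print(genes_per_sentence * approx_tokens_per_gene)  # 800
```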
In [38]:
print("\n--- Ground Truth Perturbed Cell ---")
print(ground_truth_response)
print("\n--- Predicted Perturbed Cell ---")
print(predicted_response)
--- Ground Truth Perturbed Cell ---
MT-ATP6 MT-CO3 MT-CO2 MT-ND4 MT-CYB MT-ND1 MT-ND2 ACTB HSP90AA1 TMSB4X YBX1 EEF1B2 MT-ND3 RACK1 HSP90AB1 EEF1G MIF NME2 HIST1H4C TUBA1B NCL TUBB ADA ENO1 STMN1 H2AFZ PFN1 H3F3A CFL1 LDHB HINT1 HSPD1 C1QBP HSPE1 UBA52 SERF2 ACTG1 PPIA B2M CALR HNRNPA2B1 ARHGDIB GSTP1 SET MT-ND5 BTF3 CCT2 CHCHD2 NUCB2 XRCC5 PGK1 CD3D HNRNPU SUMO2 PPP1R14B HNRNPD UQCRH FDPS ALYREF SIVA1 DNAJA1 SLC25A6 ARPC2 FTL TYMS DUT COX4I1 SNRPB DDX5 PRKDC SLC25A3 PSMA7 CD3G SLC25A5 UBC ATP5F1E MYL6 ATP5F1B CCT6A PCLAF CDK6 H3F3B EIF4A1 EEF2 HMGB2 GUK1 THRAP3 HNRNPDL SERBP1 FABP5 EIF2S2 NUCKS1 HNRNPA3 HMGN1 COX7C SFPQ NDUFA13 HSPA8 OAZ1 MARCKSL1 DEK SELENOH FTH1 SRM SNRPF EIF1 DYNLL1 XRCC6 MT-ND6 RSL1D1 ANP32B EIF3E PSMA3 NDUFS5 ERH ATP5MC3 ARPC3 GLUL YBX3 FUS SNRPD1 YWHAB ATP5MC1 PRDX1 PSMA4 SOX4 NDUFAB1 PPA1 EIF5 NUDC STOML2 NME4 SRSF7 MZB1 H1FX NOP56 TPM4 NME1 NASP EIF3I ATP6V1F SF3B2 TRIR CCNI ARPC1B COX6C SLIRP SNRPC BRK1 ARPC5 ATP5MF SRSF3 CD7 HNRNPR GNAI2 CARHSP1 PPIG ATP5MG PTGES3 PFDN2 SSBP1 C9ORF16 PAICS GPX4 UBB C12ORF57 HIST1H1B ANP32A PRDX5 APRT PKM PFDN5 NHP2 SUB1 CORO1A LSM2 HNRNPC RRM2 PNN EPRS SNRPA1 HSP90B1 MCM3 CCT8 PRRC2C SKP1 RNASEH2B CIAO2B PRMT1 RAC1 SRRT HNRNPF VDAC3 ISG15 NAA10 RRP1B CCT3 MYL12B NAE1 EMC6.

--- Predicted Perturbed Cell ---
MT-ATP6 MT-CO3 MT-CO2 MT-CYB MT-ND4 TMSB4X ACTB MT-ND1 MT-ND3 MT-ND2 HSP90AA1 HSP90AB1 RACK1 YBX1 MIF TUBA1B H3F3A H2AFZ HSPD1 NCL EEF1B2 FTH1 HNRNPA2B1 EEF1G LDHB HINT1 CHCHD2 NME2 BTF3 UBA52 EIF4A1 ACTG1 STMN1 TUBB PFN1 CFL1 EIF1 PPIA GSTP1 PRDX1 HNRNPU SET B2M HSPE1 ARHGDIB ENO1 SERF2 CD3D MT-ND5 HNRNPD HNRNPC C1QBP FTL SLC25A3 COX4I1 SLC25A5 ATP5MC3 SRSF3 PPP1R14B SIVA1 HNRNPA3 CCT3 ANP32B SFPQ EIF3E HNRNPR PSMB1 ATP5MC2 CCT6A PSMB2 NUCB2 PSMA7 PSMB3 SNRPD1 PSMB6 ATP5F1B PPA1 UQCRH SNRPB NME1 SRSF7 PSMB7 NDUFS5 HNRNPK OAZ1 PFDN5 TMSB10 ADA GTF3A ATP5F1E SERBP1 COX7C NDUFA13 PSMB5 CCT2 ATP5F1D GUK1 EIF3A NDUFB10 FABP5 PSMB3 NDUFA4 ATP5F1A NUDC POMP HNRNPDL UBB PSMB6 NDUFS6 COX6A1 SRSF2 EIF3M SNRPE PSMA4 PSMD7 GADD45GIP1 ATP5MG PRMT1 NOP53 PSMD2 UQCRB YWHAE HNRNPM PSMD1 NDUFB11 EIF3H TOMM6 PSMD11 PSMC5 SRSF9 PSMC3 NDUFA11 NOP56 EIF3I EIF3G ATP5PO SNRPF RBM8A PSMD4 TOMM22 RTRAF SRSF11 PSMC1 NDUFAB1 HNRNPAB PSMD13 PSMA2 NUCKS1 UQCR10 RSL1D1 PSMB4 TOMM5 GYPC PFDN2 PSMA3 NDUFA12 NOP10 EIF4G2 SELENOH UBE2I RBM3 PSMD12 PSMC4 HNRNPD SRM PSMB8 EIF3L RBMX PSMD6 SLC25A6 TOMM20 UBE2L3 PSMB5 RBM25 EIF3D TMA7 PNN RBM39 UQCR11 GTF3C6 PPP1CA EIF2S2 PSMC2 NUDT8 PSMD11 RBM8B RBM3B TOMM40 RAC1 TUFM EIF4B PSMB7 RSL24D1 GTF2A2 NOL7 PSMD12.
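One quick, informal way to gauge prediction quality (a sketch, not part of the C2S API) is to compare the gene sets of the ground-truth and predicted sentences with a Jaccard overlap; since cell sentences are rank-ordered by expression, rank-aware metrics would be a natural next step:

```python
def jaccard_overlap(sentence_a: str, sentence_b: str) -> float:
    """Jaccard similarity between the gene sets of two cell sentences."""
    genes_a = set(sentence_a.rstrip(".").split())
    genes_b = set(sentence_b.rstrip(".").split())
    return len(genes_a & genes_b) / len(genes_a | genes_b)

# Toy example; in practice pass ground_truth_response and predicted_response.
print(jaccard_overlap("ACTB MT-CO2 TMSB4X.", "ACTB MT-CO2 GAPDH."))  # 0.5
```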

Conclusion¶

In this tutorial, you learned how to finetune a Cell2Sentence model for a custom task: perturbation response prediction. By creating a PerturbationPromptFormatter, we were able to structure our data into control-perturbation-response triplets, enabling the model to learn the complex transcriptional changes that occur upon perturbation.

This approach is highly flexible and can be adapted to various experimental designs. The finetuned model can now be used for in silico experiments, such as virtual screening of genetic perturbations or predicting the effects of new compounds, significantly accelerating the pace of biological discovery.

In [ ]: